//-*-C++-*-

/***************************************************************************
 *
 *   Copyright (C) 2024 by Will Gauvin
 *   Licensed under the Academic Free License version 2.1
 *
 ***************************************************************************/

#include <assert.h>

#include "dsp/RescaleCUDA.h"

#define FULLMASK 0xFFFFFFFF

using namespace std;

static constexpr unsigned nthreads = 1024;
static constexpr unsigned warp_size = 32;

void check_error_stream(const char *, cudaStream_t);

//! Construct the engine on the given CUDA stream
/*! All buffer pointers are initialised to nullptr so that the destructor can
    safely free them even if init() is never called. */
CUDA::RescaleEngine::RescaleEngine(cudaStream_t _stream)
{
  stream = _stream;
  scratch = new dsp::Scratch;

  // device accumulators and statistics
  d_freq_total = nullptr;
  d_freq_totalsq = nullptr;

  d_scale = h_scale = nullptr;
  d_offset = h_offset = nullptr;

  // pinned host copies of the frequency totals (allocated in init)
  h_freq_total = nullptr;
  h_freq_totalsq = nullptr;
}

//! Release all device and pinned host buffers
CUDA::RescaleEngine::~RescaleEngine()
{
  // NOTE(review): scratch is allocated with new in the constructor but only
  // detached here — presumably reference-managed elsewhere; confirm no leak
  scratch = nullptr;

  // release device buffers (cudaFree of a nullptr is a harmless no-op)
  cudaFree(d_freq_total);
  d_freq_total = nullptr;

  cudaFree(d_freq_totalsq);
  d_freq_totalsq = nullptr;

  cudaFree(d_scale);
  d_scale = nullptr;

  cudaFree(d_offset);
  d_offset = nullptr;

  // release the pinned host buffers allocated in init()
  cudaFreeHost(h_scale);
  h_scale = nullptr;

  cudaFreeHost(h_offset);
  h_offset = nullptr;

  // these were previously leaked: the old code re-freed the (already nulled)
  // device pointers instead of the pinned host totals, and freed h_offset twice
  cudaFreeHost(h_freq_total);
  h_freq_total = nullptr;

  cudaFreeHost(h_freq_totalsq);
  h_freq_totalsq = nullptr;
}

//! Record the input dimensions and allocate device and pinned-host buffers
/*!
  @param input used only to query npol, ndim and nchan
  @param _nsample number of time samples per rescale interval
  @param _exact when true, each input is assumed to hold exactly nsample samples
  @param _constant_offset_scale when true, the offset and scale are computed on
         the first integration only and reused thereafter
*/
void CUDA::RescaleEngine::init(const dsp::TimeSeries *input, uint64_t _nsample, bool _exact, bool _constant_offset_scale)
{
  gpu_config.init();

  nsample = _nsample;
  exact = _exact;
  constant_offset_scale = _constant_offset_scale;

  npol = input->get_npol();
  ndim = input->get_ndim();
  nchan = input->get_nchan();

  // one float per chan/pol for each device buffer and host scale/offset
  data_size_bytes = npol * nchan * sizeof(float);
  // the host frequency totals are stored in double precision
  auto host_freq_size = npol * nchan * sizeof(double);

  cudaError_t error;

  // allocate and zero device memory (memsets are asynchronous on this stream)
  error = cudaMalloc(&d_freq_total, data_size_bytes);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMalloc d_freq_total failed");

  error = cudaMemsetAsync(d_freq_total, 0, data_size_bytes, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMemsetAsync d_freq_total failed");

  error = cudaMalloc(&d_freq_totalsq, data_size_bytes);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMalloc d_freq_totalsq failed");

  error = cudaMemsetAsync(d_freq_totalsq, 0, data_size_bytes, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMemsetAsync d_freq_totalsq failed");

  error = cudaMalloc(&d_scale, data_size_bytes);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMalloc d_scale failed");

  error = cudaMemsetAsync(d_scale, 0, data_size_bytes, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMemsetAsync d_scale failed");

  error = cudaMalloc(&d_offset, data_size_bytes);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMalloc d_offset failed");

  error = cudaMemsetAsync(d_offset, 0, data_size_bytes, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMemsetAsync d_offset failed");

  // allocate pinned host memory for efficient asynchronous device-to-host copies
  error = cudaMallocHost(&h_scale, data_size_bytes);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMallocHost h_scale failed");

  error = cudaMallocHost(&h_offset, data_size_bytes);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMallocHost h_offset failed");

  error = cudaMallocHost(&h_freq_total, host_freq_size);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMallocHost h_freq_total failed");

  error = cudaMallocHost(&h_freq_totalsq, host_freq_size);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::init", "cudaMallocHost h_freq_totalsq failed");

  first_integration = true;
}

/*
 Compute a sum of a float across a warp.

 This is a utility kernel to reduce the sum of a value across a warp.

 @param val the value for a given thread within a warp
 @returns the value summed across all threads in the warp.
 */
__inline__ __device__ float rescale_warp_reduce_sum(float val)
{
  // tree reduction using down-shuffles: halve the lane distance each step
  int delta = warpSize / 2;
  while (delta > 0)
  {
#if (__CUDA_ARCH__ >= 300)
#if (__CUDACC_VER_MAJOR__ >= 9)
    val += __shfl_down_sync(FULLMASK, val, delta);
#else
    val += __shfl_down(val, delta);
#endif
#endif
    delta /= 2;
  }
  // lane 0 ends up holding the sum over all lanes
  return val;
}

/*
  Compute a sum of a value across a whole block.

  This is utility kernel to help getting a sum of a value across a block.

  @param val the value for a given thread within a block
  @returns the value summed across all threads in the block.
 */
__inline__ __device__ float rescale_block_reduce_sum(float val)
{
  // shared mem for 32 partial sums (one per warp; 1024 threads / 32 lanes)
  // NOTE(review): this buffer is reused across back-to-back invocations, so
  // callers must __syncthreads() between two calls (see the fpt calc kernel)
  __shared__ float shared[32];

  int lane = threadIdx.x % warpSize;  // lane index within the warp
  int wid = threadIdx.x / warpSize;   // warp index within the block

  // each warp performs partial reduction
  val = rescale_warp_reduce_sum(val);

  // write reduced value to shared memory
  if (lane == 0)
    shared[wid] = val;

  // wait for all partial reductions
  __syncthreads();

  // read from shared memory only if that warp existed
  // NOTE(review): assumes blockDim.x is a multiple of warpSize (launches in
  // this file use nthreads = 1024); a partial final warp would be dropped here
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;

  // final reduce within first warp; only thread 0 is guaranteed to hold the
  // complete block total afterwards
  if (wid == 0)
    val = rescale_warp_reduce_sum(val);

  return val;
}

/*
  Calculate the offsets, scales, freq total and freq squared totals for TFP-ordered data.

  @param in_ptr base address of the TimeSeries from which data statistics are to be calculated
  @param freq_total the base address of where to output the frequency totals are stored in FP ordering.
  @param freq_totalsq the base address of where the output of the frequency squared totals are stored in FP order.
  @param offset the base address of where the computed offsets are stored in FP order.
  @param scale the base address of where the computed scales are stored in FP order.
  @param ndat the total number of time samples to be used to calculate the statistics.
  @param nchan the number of channels that the input timeseries has.
  @param npol the number of polarizations that the input timeseries has.
  @param ndim the number of dimensions the samples in the input timeseries has.
  @param recip the scale factor used when calculating the mean.

  The scale is 1.0/sqrt(variance) while the offset is the negative value of the mean.  This is consistent with
  the CPU implementation.

  recip is typically 1.0 / (ndat * ndim)

  number of CUDA blocks - typically ceil(nchan * npol / warpSize)
  each warp processes exactly 1 channel+pol (ichanpol) combination and each block of 1024 threads will process
  warpSize ichanpols.

  warp_idx = threadIdx.x % warpSize - the index of the current thread within a warp
  warp_num = threadIdx.x / warpSize - a number representing which warp within block the thread belongs to

  ichanpol = (blockIdx.x * warpSize) + warp_num - which channel + pol combination the thread is working on.
*/
__global__ void rescale_calc_offset_scale_tfp(const float *in_ptr, float *freq_total, float *freq_totalsq, float *offset, float *scale, unsigned ndat, unsigned nchan, unsigned npol, unsigned ndim, float recip)
{
  const unsigned nchanpol = nchan * npol;
  const unsigned warp_idx = threadIdx.x % warpSize;  // lane within the warp
  const unsigned warp_num = threadIdx.x / warpSize;  // warp within the block

  // each warp processes 1 channel, each block of 1024 processes 32 channels
  unsigned ichanpol = (blockIdx.x * warpSize) + warp_num;

  // the sample offset for this thread's first sample
  uint64_t idx = ((warp_idx * nchanpol) + ichanpol) * ndim;

  float freq_total_thread = 0.0f;
  float freq_totalsq_thread = 0.0f;

  // the last block may contain warps beyond the final chan/pol combination
  if (ichanpol < nchanpol)
  {
    // promote to 64-bit before multiplying to avoid 32-bit overflow
    const uint64_t warp_stride = static_cast<uint64_t>(nchanpol) * ndim * warpSize;

    // accumulate sum and sum-of-squares over all samples of this chan/pol
    for (uint64_t idat = warp_idx; idat < ndat; idat += warpSize)
    {
      for (unsigned idim = 0; idim < ndim; idim++)
      {
        const float in_val = in_ptr[idx + idim];
        freq_total_thread += in_val;
        freq_totalsq_thread += (in_val * in_val);
      }
      idx += warp_stride;
    }

    // reduce the per-lane partial sums across the warp.  No __syncthreads() is
    // needed (or legal) here: the reduction uses only warp shuffles, and this
    // branch is divergent in the final block — a block-wide barrier inside it
    // would be undefined behaviour.
    freq_total_thread = rescale_warp_reduce_sum(freq_total_thread);
    freq_totalsq_thread = rescale_warp_reduce_sum(freq_totalsq_thread);

    // lane 0 holds the warp totals; store the statistics in FP order
    if (warp_idx == 0)
    {
      freq_total[ichanpol] = freq_total_thread;
      freq_totalsq[ichanpol] = freq_totalsq_thread;

      const float mean = freq_total_thread * recip;
      const float variance = freq_totalsq_thread * recip - mean * mean;

      // offset removes the mean; scale normalises to unit variance
      // (matching the CPU implementation)
      offset[ichanpol] = -mean;
      if (variance <= 0.0f)
        scale[ichanpol] = 1.0f;
      else
        scale[ichanpol] = 1.0f / sqrtf(variance);
    }
  }
}

/*
  Apply the current offset and scales to a TFP-ordered Timeseries and write it to an output TFP-ordered Timeseries.

  @param in base address of TimeSeries that should be rescaled.
  @param out the base address of the TimeSeries that rescaled values should be written to.
  @param offset the base address of where the computed offsets are stored in FP order.
  @param scale the base address of where the computed scales are stored in FP order.
  @param nchan the number of channels that the input timeseries has.
  @param npol the number of polarizations that the input timeseries has.
  @param ndim the number of dimensions the samples in the input timeseries has.

  Kernel assumes each CUDA block processes exactly 1 time sample.

  out[idx] = (in[idx] + offset[ichanpol]) * scale[ichanpol]
*/
__global__ void rescale_apply_offset_scale_tfp(const float *in,
                                               float *out,
                                               const float *offset, const float *scale,
                                               const unsigned nchan, const unsigned npol, const unsigned ndim)
{
  // each block processes 1 time sample; promote to 64-bit BEFORE multiplying,
  // otherwise the product is evaluated in 32-bit unsigned arithmetic and can
  // silently overflow for large ndat * nchan * npol * ndim
  const uint64_t dat_offset = static_cast<uint64_t>(blockIdx.x) * nchan * npol * ndim;

  // threads of the block stride across the channels of this sample
  for (unsigned ichan = threadIdx.x; ichan < nchan; ichan += blockDim.x)
  {
    // input and output offsets of this channel within the sample
    uint64_t idx = dat_offset + (ichan * npol * ndim);

    for (unsigned ipol = 0; ipol < npol; ipol++)
    {
      // offset and scale are stored in FP order
      const unsigned ichanpol = ichan * npol + ipol;
      const float ichanpol_offset = offset[ichanpol];
      const float ichanpol_scale = scale[ichanpol];
      for (unsigned idim = 0; idim < ndim; idim++)
      {
        // convert to 0 mean and unit variance
        out[idx] = (in[idx] + ichanpol_offset) * ichanpol_scale;
        idx++;
      }
    }
  }
}

/*
  Calculate the offsets, scales, freq total and freq squared totals for FPT-ordered data.

  @param in base address of the TimeSeries from which data statistics are to be calculated
  @param chanpol_stride the stride in data between a channel and a polarisation.
  @param freq_total the base address of where to output the frequency totals are stored in FP ordering.
  @param freq_totalsq the base address of where the output of the frequency squared totals are stored in FP order.
  @param offset the base address of where the computed offsets are stored in FP order.
  @param start_dat which time sample to start from, usually 0 but may not be.
  @param ndat the total number of time samples to be used to calculate the statistics.
  @param ndim the number of dimensions the samples in the input timeseries has.
  @param recip the scale factor used when calculating the mean.

  The scale is 1.0/sqrt(variance) while the offset is the negative value of the mean.  This is consistent with
  the CPU implementation.

  recip is typically 1.0 / (ndat * ndim)

  number of CUDA blocks - typically nchan * npol
  ichanpol = blockIdx.x

  each block of nthreads threads processes exactly 1 channel+pol (ichanpol) combination, with
  the threads of the block cooperatively reducing that combination's time samples.
*/
__global__ void rescale_calc_offset_scale_fpt(const float *in, uint64_t chanpol_stride,
                                              float *freq_total, float *freq_totalsq,
                                              float *offset, float *scale,
                                              uint64_t start_dat, uint64_t ndat, unsigned ndim, float recip)
{
  // one block per chan/pol combination
  const unsigned ichanpol = blockIdx.x;

  // chanpol_stride is now 64-bit (matching the apply kernel and the host-side
  // value), so this offset cannot overflow 32-bit arithmetic
  const uint64_t chanpol_offset = ichanpol * chanpol_stride;

  float freq_total_thread = 0.0f;
  float freq_totalsq_thread = 0.0f;

  const float *in_ptr = in + chanpol_offset;

  // threads of the block stride across the time samples of this chan/pol
  for (uint64_t idat = threadIdx.x; idat < ndat; idat += blockDim.x)
  {
    uint64_t idat_idx = (idat + start_dat) * ndim;
    for (unsigned idim = 0; idim < ndim; idim++)
    {
      const float in_val = in_ptr[idat_idx + idim];
      freq_total_thread += in_val;
      freq_totalsq_thread += (in_val * in_val);
    }
  }

  // sum across block
  freq_total_thread = rescale_block_reduce_sum(freq_total_thread);

  // barrier required: both reductions reuse the same shared-memory buffer
  __syncthreads();

  // sum across block
  freq_totalsq_thread = rescale_block_reduce_sum(freq_totalsq_thread);

  __syncthreads();

  // thread 0 holds the block totals; store the statistics in FP order
  if (threadIdx.x == 0)
  {
    freq_total[ichanpol] = freq_total_thread;
    freq_totalsq[ichanpol] = freq_totalsq_thread;

    const float mean = freq_total_thread * recip;
    const float variance = freq_totalsq_thread * recip - mean * mean;

    // offset removes the mean; scale normalises to unit variance
    // (matching the CPU implementation)
    offset[ichanpol] = -mean;
    if (variance <= 0.0f)
      scale[ichanpol] = 1.0f;
    else
      scale[ichanpol] = 1.0f / sqrtf(variance);
  }
}

/*
  Apply the current offset and scales to a FPT-ordered Timeseries and write it to an output FPT-ordered Timeseries.

  @param in base address of TimeSeries that should be rescaled.
  @param out the base address of the TimeSeries that rescaled values should be written to.
  @param chanpol_stride the stride in data between a channel and a polarisation.
  @param offset the base address of where the computed offsets are stored in FP order.
  @param start_dat which time sample to start from, usually 0 but may not be.
  @param ndat the total number of time samples to be used to calculate the statistics.
  @param ndim the number of dimensions the samples in the input timeseries has.

  out[idx] = (in[idx] + offset[ichanpol]) * scale[ichanpol]

*/
__global__ void rescale_apply_offset_scale_fpt(const float *in, float *out,
                                               uint64_t chanpol_stride,
                                               float *offset, float *scale,
                                               uint64_t start_dat, uint64_t ndat, unsigned ndim)
{
  // one block per chan/pol combination
  const unsigned ichanpol = blockIdx.x;

  // base pointers for this chan/pol in the input and output series
  const float *src = in + ichanpol * chanpol_stride;
  float *dst = out + ichanpol * chanpol_stride;

  // offset and scale are stored in FP order
  const float chanpol_offset_val = offset[ichanpol];
  const float chanpol_scale_val = scale[ichanpol];

  // threads of the block stride across the time samples of this chan/pol
  uint64_t idat = threadIdx.x;
  while (idat < ndat)
  {
    const uint64_t base = (start_dat + idat) * ndim;
    for (unsigned idim = 0; idim < ndim; idim++)
    {
      // convert to 0 mean and unit variance
      dst[base + idim] = (src[base + idim] + chanpol_offset_val) * chanpol_scale_val;
    }
    idat += blockDim.x;
  }
}

//! Rescale the input TimeSeries into the output, recomputing offsets/scales as required
/*!
  Processes the input in segments of nsample time samples.  For each segment the
  per-chan/pol offset and scale are (re)computed on the GPU unless
  constant_offset_scale is set and the first integration has been done.  After
  all segments are processed, the offsets, scales and frequency totals are
  copied to pinned host buffers in PF order for the get_* accessors.

  @param input TimeSeries to rescale (TFP or FPT ordered)
  @param output TimeSeries to which the rescaled data are written
*/
void CUDA::RescaleEngine::transform(const dsp::TimeSeries *input, dsp::TimeSeries *output)
{
  cudaError_t error;
  const auto nchanpol = nchan * npol;
  const auto input_ndat = input->get_ndat();

  auto calculate_offset_scale = first_integration || !constant_offset_scale;

  // number of nsample-length segments in the input (ceiling division)
  // when exact, the upstream buffering guarantees exactly nsample samples
  uint64_t num_iters = exact ? 1 : (input_ndat + nsample - 1) / nsample;

  if (dsp::Operation::verbose)
    cerr << "CUDA::RescaleEngine::transform input->get_order()=" << input->get_order()
         << ", nsample=" << nsample << ", ndim=" << ndim << ", npol=" << npol << ", nchan="
         << nchan << ", nchanpol=" << nchanpol << ", input_ndat=" << input_ndat << ", num_iters=" << num_iters
         << ", calculate_offset_scale=" << calculate_offset_scale << endl;

  // stride (in floats) between consecutive chan/pol blocks of FPT-ordered data,
  // derived from the spacing of the first two data pointers; computed in one
  // place because the calc and apply launches need the same value
  auto get_chanpol_stride = [&]() -> uint64_t
  {
    const float *first_chanpol = input->get_datptr(0, 0);
    if (npol == 1 && nchan > 1)
      return input->get_datptr(1, 0) - first_chanpol;
    if (npol > 1)
      return input->get_datptr(0, 1) - first_chanpol;
    return 0;
  };

  for (uint64_t iter = 0; iter < num_iters; iter++)
  {
    auto start_dat = iter * nsample;
    auto end_dat = start_dat + nsample;
    if (end_dat > input_ndat)
      end_dat = input_ndat;

    // number of samples in this segment: nsample or less
    auto nsamp = end_dat - start_dat;
    if (dsp::Operation::verbose)
      cerr << "CUDA::RescaleEngine::transform nsamp=" << nsamp << endl;

    if (nsamp == 0)
      break;

    if (calculate_offset_scale)
    {
      // reciprocal of the number of values contributing to each mean
      const float recip = 1.0f / static_cast<float>(nsamp * ndim);

      if (dsp::Operation::verbose)
        cerr << "CUDA::RescaleEngine::transform calculating offset_scale with recip=" << recip << endl;

      // zero the accumulators before the sums are computed
      error = cudaMemsetAsync(d_freq_total, 0, data_size_bytes, stream);
      if (error != cudaSuccess)
        throw Error(FailedCall, "CUDA::RescaleEngine::transform", "cudaMemsetAsync d_freq_total failed");

      error = cudaMemsetAsync(d_freq_totalsq, 0, data_size_bytes, stream);
      if (error != cudaSuccess)
        throw Error(FailedCall, "CUDA::RescaleEngine::transform", "cudaMemsetAsync d_freq_totalsq failed");

      // compute the sums, offsets and scales for this segment
      switch (input->get_order())
      {
      case dsp::TimeSeries::OrderTFP:
      {
        const float *in_ptr = input->get_dattfp();
        uint64_t ptr_offset = start_dat * nchanpol * ndim;

        // one warp per chan/pol: ceil(nchanpol / warp_size) blocks of nthreads
        auto nblocks = (nchanpol + warp_size - 1) / warp_size;

        if (dsp::Operation::verbose)
          cerr << "CUDA::RescaleEngine::transform calling rescale_calc_offset_scale_tfp nblocks=" << nblocks
               << ", nthreads=" << nthreads << ", ptr_offset=" << ptr_offset << endl;

        rescale_calc_offset_scale_tfp<<<nblocks, nthreads, 0, stream>>>(in_ptr + ptr_offset, d_freq_total, d_freq_totalsq, d_offset, d_scale, nsamp, nchan, npol, ndim, recip);
        if (dsp::Operation::record_time || dsp::Operation::verbose)
          check_error_stream("CUDA::RescaleEngine::transform rescale_calc_offset_scale_tfp", stream);

        break;
      }
      case dsp::TimeSeries::OrderFPT:
      {
        // one block per chan/pol combination
        unsigned nblocks = nchanpol;
        uint64_t chanpol_stride = get_chanpol_stride();

        if (dsp::Operation::verbose)
          cerr << "CUDA::RescaleEngine::transform calling rescale_calc_offset_scale_fpt nblocks=" << nblocks
               << ", nthreads=" << nthreads << ", chanpol_stride=" << chanpol_stride << endl;

        rescale_calc_offset_scale_fpt<<<nblocks, nthreads, 0, stream>>>(input->get_datptr(0, 0), chanpol_stride,
                                                                        d_freq_total, d_freq_totalsq,
                                                                        d_offset, d_scale,
                                                                        start_dat, nsamp, ndim, recip);

        if (dsp::Operation::record_time || dsp::Operation::verbose)
          check_error_stream("CUDA::RescaleEngine::transform rescale_calc_offset_scale_fpt", stream);

        break;
      }
      default:
        // previously an unknown ordering was silently ignored
        throw Error(InvalidState, "CUDA::RescaleEngine::transform", "unsupported TimeSeries order");
      }
    }

    // apply the (possibly just recomputed) offsets and scales
    switch (input->get_order())
    {
    case dsp::TimeSeries::OrderTFP:
    {
      const float *in = input->get_dattfp();
      uint64_t ptr_offset = start_dat * nchanpol * ndim;
      // one block per time sample
      unsigned nblocks = nsamp;

      if (dsp::Operation::verbose)
        cerr << "CUDA::RescaleEngine::transform calling rescale_apply_offset_scale_tfp nblocks=" << nblocks
             << ", nthreads=" << nthreads << ", ptr_offset=" << ptr_offset << endl;

      rescale_apply_offset_scale_tfp<<<nblocks, nthreads, 0, stream>>>(in + ptr_offset, output->get_dattfp() + ptr_offset,
                                                                       d_offset, d_scale, nchan, npol, ndim);

      if (dsp::Operation::record_time || dsp::Operation::verbose)
        check_error_stream("CUDA::RescaleEngine::transform rescale_apply_offset_scale_tfp", stream);

      break;
    }
    case dsp::TimeSeries::OrderFPT:
    {
      // one block per chan/pol combination
      unsigned nblocks = nchanpol;
      uint64_t chanpol_stride = get_chanpol_stride();

      if (dsp::Operation::verbose)
        cerr << "CUDA::RescaleEngine::transform calling rescale_apply_offset_scale_fpt nblocks=" << nblocks
             << ", nthreads=" << nthreads << ", chanpol_stride=" << chanpol_stride << endl;

      rescale_apply_offset_scale_fpt<<<nblocks, nthreads, 0, stream>>>(input->get_datptr(0, 0), output->get_datptr(0, 0),
                                                                       chanpol_stride,
                                                                       d_offset, d_scale,
                                                                       start_dat, nsamp, ndim);
      if (dsp::Operation::record_time || dsp::Operation::verbose)
        check_error_stream("CUDA::RescaleEngine::transform rescale_apply_offset_scale_fpt", stream);

      break;
    }
    default:
      throw Error(InvalidState, "CUDA::RescaleEngine::transform", "unsupported TimeSeries order");
    }
  }

  // ensure we don't calculate scale and offset again if we don't need to.
  first_integration = false;

  // scratch large enough for offset, scale, freq_total and freq_totalsq (FP order)
  float *scratch_data = (float *)scratch->space(4 * data_size_bytes);

  error = cudaMemcpyAsync(scratch_data, d_offset, data_size_bytes, cudaMemcpyDeviceToHost, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::transform", "cudaMemcpyAsync from d_offset to scratch_data failed");

  error = cudaMemcpyAsync(scratch_data + nchanpol, d_scale, data_size_bytes, cudaMemcpyDeviceToHost, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::transform", "cudaMemcpyAsync from d_scale to scratch_data failed");

  error = cudaMemcpyAsync(scratch_data + (2 * nchanpol), d_freq_total, data_size_bytes, cudaMemcpyDeviceToHost, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::transform", "cudaMemcpyAsync from d_freq_total to scratch_data failed");

  error = cudaMemcpyAsync(scratch_data + (3 * nchanpol), d_freq_totalsq, data_size_bytes, cudaMemcpyDeviceToHost, stream);
  if (error != cudaSuccess)
    throw Error(FailedCall, "CUDA::RescaleEngine::transform", "cudaMemcpyAsync from d_freq_totalsq to scratch_data failed");

  // this will perform a cudaStreamSynchronize and handle the error
  check_error_stream("CUDA::RescaleEngine::transform stream sync", stream);

  // transpose from the device's FP order to the PF order used by the accessors
  for (unsigned ipol = 0; ipol < npol; ipol++)
  {
    for (unsigned ichan = 0; ichan < nchan; ichan++)
    {
      auto ichanpol = ichan * npol + ipol;
      auto ipolchan = ipol * nchan + ichan;
      h_offset[ipolchan] = scratch_data[ichanpol];
      h_scale[ipolchan] = scratch_data[ichanpol + nchanpol];
      h_freq_total[ipolchan] = scratch_data[ichanpol + 2 * nchanpol];
      h_freq_totalsq[ipolchan] = scratch_data[ichanpol + 3 * nchanpol];
    }
  }

  if (dsp::Operation::verbose)
    cerr << "CUDA::RescaleEngine::transform exiting" << endl;
}

//! Return the per-channel offsets for one polarization (PF-ordered host buffer)
const float *CUDA::RescaleEngine::get_offset(unsigned ipol) const
{
  assert(ipol < npol);
  // each polarization owns a contiguous run of nchan values
  return h_offset + ipol * nchan;
}

//! Return the per-channel scales for one polarization (PF-ordered host buffer)
const float *CUDA::RescaleEngine::get_scale(unsigned ipol) const
{
  assert(ipol < npol);
  // each polarization owns a contiguous run of nchan values
  return h_scale + ipol * nchan;
}

//! Return the per-channel frequency totals for one polarization (PF-ordered host buffer)
const double *CUDA::RescaleEngine::get_freq_total(unsigned ipol) const
{
  assert(ipol < npol);
  // each polarization owns a contiguous run of nchan values
  return h_freq_total + ipol * nchan;
}

//! Return the per-channel frequency squared totals for one polarization (PF-ordered host buffer)
const double *CUDA::RescaleEngine::get_freq_squared_total(unsigned ipol) const
{
  assert(ipol < npol);
  // each polarization owns a contiguous run of nchan values
  return h_freq_totalsq + ipol * nchan;
}
